import torch


def train_dn_one_iter(train_cnn_predictions_main, train_actual_output, all_models, optimizers, criterion, L1_WEIGHT,
                      cfg, action="average"):
    """
    Run one optimisation step on every DN and return the gradients w.r.t. the CNN predictions.

    For each label column ``i`` of ``train_actual_output``, ``all_models[i]`` is fed the
    remaining label columns concatenated (along dim 1) with a detached copy of the CNN
    predictions and is trained to predict column ``i``. The loss is
    ``criterion(prediction, target)`` plus ``L1_WEIGHT`` times the model's own
    ``compute_l1_loss`` applied to its flattened parameters. The gradient of that loss
    with respect to the copied CNN predictions is collected for every label and reduced
    across labels.

    Args:
        train_cnn_predictions_main: CNN predictions, indexed as (batch, features). The
            tensor is detached and cloned for every label, so the caller's tensor and
            its autograd graph are left untouched.
        train_actual_output: Ground-truth labels, indexed as (batch, num_labels).
        all_models: Sequence of DN models, one per label column. Each must be callable
            and expose ``parameters()`` and ``compute_l1_loss(flat_parameters)``.
        optimizers: Sequence of optimizers, ``optimizers[i]`` updating ``all_models[i]``.
        criterion: Callable ``criterion(prediction, target)`` returning a scalar loss.
        L1_WEIGHT: Multiplier applied to the L1 penalty.
        cfg: Unused; kept so existing callers do not break.
        action: How per-label gradients are combined: ``"average"`` or ``"sum"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        A tuple ``(gradients, all_models)`` where ``gradients`` has the same shape as
        ``train_cnn_predictions_main`` and ``all_models`` is the same (now updated)
        sequence that was passed in.

    Raises:
        ValueError: If ``action`` is neither ``"average"`` nor ``"sum"``. The check
            happens before any model is updated.
    """
    reducers = {"average": torch.mean, "sum": torch.sum}
    mode = action.strip().lower()
    if mode not in reducers:
        # Fail before touching any model instead of silently returning None afterwards.
        raise ValueError(
            "action must be 'average' or 'sum', got {!r}".format(action)
        )

    gradients_for_CNN = []
    num_true_label = train_actual_output.shape[1]
    for label_index in range(num_true_label):
        model = all_models[label_index]
        optimizer = optimizers[label_index]

        this_train_y = train_actual_output[:, label_index]
        # Every label column except the current one; the slices are simply empty
        # at the first/last index, so no special-casing is required.
        all_other_actual_labels = torch.cat(
            (train_actual_output[:, :label_index], train_actual_output[:, label_index + 1:]), 1
        )

        # Fresh leaf tensor per label so each backward pass yields an independent gradient.
        train_cnn_predictions = train_cnn_predictions_main.detach().clone().requires_grad_(True)
        this_train_x = torch.cat((all_other_actual_labels, train_cnn_predictions), 1)

        optimizer.zero_grad()
        outputs = model(this_train_x)

        predictions = torch.squeeze(outputs)
        if predictions.dim() == 0 and this_train_y.numel() == 1:
            # A batch of one is squeezed down to a scalar; restore the target's
            # shape so criteria that require matching shapes keep working.
            predictions = predictions.reshape(this_train_y.shape)
        loss = criterion(predictions, this_train_y)

        # Add the L1 penalty over all of this model's parameters.
        l1_parameters = [parameter.view(-1) for parameter in model.parameters()]
        loss = loss + L1_WEIGHT * model.compute_l1_loss(torch.cat(l1_parameters))

        loss.backward()
        gradients_for_CNN.append(train_cnn_predictions.grad)
        optimizer.step()

    stacked_gradients = torch.stack(gradients_for_CNN)
    return reducers[mode](stacked_gradients, dim=0), all_models


def train_epoch_dn(
        train_loader,
        model,
        dn_models,
        dn_optimizers,
        dn_criterion,
        cfg,
):
    """
    Train the DN models for one epoch on the predictions of the video model.

    The video model itself is not updated: it is put in eval mode and its forward
    pass runs under ``torch.no_grad()``. Its predictions, together with the labels,
    are handed to ``train_dn_one_iter`` which performs one optimisation step per
    DN model for every batch.

    Args:
        train_loader (loader): video training loader. Each item is unpacked as
            ``(inputs, labels, _, meta)``; ``inputs`` may be a tensor or a list of
            tensors and ``meta`` is a dict whose values are tensors or lists of tensors.
        model (model): the video model producing the predictions.
        dn_models: DN models, passed through to ``train_dn_one_iter``.
        dn_optimizers: optimizers for ``dn_models``, passed through to
            ``train_dn_one_iter``.
        dn_criterion: loss function for the DN models, passed through to
            ``train_dn_one_iter``.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py. This function reads ``NUM_GPUS``,
            ``TRAIN.MIXED_PRECISION``, ``DETECTION.ENABLE`` and
            ``JOINT_LEARNING.L1_WEIGHT``.
    """
    # The video model is only used for inference here, so keep it in eval mode.
    model.eval()

    for inputs, labels, _, meta in train_loader:
        # Transfer the data to the current GPU device.
        if cfg.NUM_GPUS:
            if isinstance(inputs, list):
                for i, tensor in enumerate(inputs):
                    inputs[i] = tensor.cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda()
            for key, val in meta.items():
                if isinstance(val, list):
                    for i, item in enumerate(val):
                        val[i] = item.cuda(non_blocking=True)
                else:
                    meta[key] = val.cuda(non_blocking=True)

        # No gradients are needed for the video model's forward pass.
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
                if cfg.DETECTION.ENABLE:
                    preds = model(inputs, meta["boxes"])
                else:
                    preds = model(inputs)

        # Train the DN. Predictions are cast to float32 because autocast may have
        # produced a lower-precision tensor; the returned gradients are not used here.
        _, dn_models = train_dn_one_iter(preds.float(), labels.float(), dn_models,
                                         dn_optimizers, dn_criterion, cfg.JOINT_LEARNING.L1_WEIGHT,
                                         cfg, action="average")
